{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "87428f48",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "# Custom Layers\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5b2079e0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:10.693752Z",
     "iopub.status.busy": "2023-08-18T19:31:10.693415Z",
     "iopub.status.idle": "2023-08-18T19:31:13.986742Z",
     "shell.execute_reply": "2023-08-18T19:31:13.985398Z"
    },
    "origin_pos": 3,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5cc754f",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "Layers without Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "34c473c6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:14.006461Z",
     "iopub.status.busy": "2023-08-18T19:31:14.005870Z",
     "iopub.status.idle": "2023-08-18T19:31:14.035296Z",
     "shell.execute_reply": "2023-08-18T19:31:14.034301Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CenteredLayer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, X):\n",
    "        return X - X.mean()\n",
    "\n",
    "layer = CenteredLayer()\n",
    "layer(torch.tensor([1.0, 2, 3, 4, 5]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "539ec9d8",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Incorporate our layer as a component\n",
    "in constructing more complex models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "370e0abb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:14.048310Z",
     "iopub.status.busy": "2023-08-18T19:31:14.047972Z",
     "iopub.status.idle": "2023-08-18T19:31:14.059041Z",
     "shell.execute_reply": "2023-08-18T19:31:14.057938Z"
    },
    "origin_pos": 20,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-6.5193e-09, grad_fn=<MeanBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chain a lazily initialized linear layer with the custom centering layer.\n",
    "net = nn.Sequential(nn.LazyLinear(128), CenteredLayer())\n",
    "\n",
    "Y = net(torch.rand(4, 8))\n",
    "# CenteredLayer subtracts the overall mean, so this should be ~0 up to rounding.\n",
    "Y.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1053ec6",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Layers with Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8a664799",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:14.074206Z",
     "iopub.status.busy": "2023-08-18T19:31:14.073211Z",
     "iopub.status.idle": "2023-08-18T19:31:14.080883Z",
     "shell.execute_reply": "2023-08-18T19:31:14.079861Z"
    },
    "origin_pos": 31,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 0.4783,  0.4284, -0.0899],\n",
       "        [-0.6347,  0.2913, -0.0822],\n",
       "        [-0.4325, -0.1645, -0.3274],\n",
       "        [ 1.1898,  0.6482, -1.2384],\n",
       "        [-0.1479,  0.0264, -0.9597]], requires_grad=True)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MyLinear(nn.Module):\n",
    "    def __init__(self, in_units, units):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(torch.randn(in_units, units))\n",
    "        self.bias = nn.Parameter(torch.randn(units,))\n",
    "\n",
    "    def forward(self, X):\n",
    "        linear = torch.matmul(X, self.weight.data) + self.bias.data\n",
    "        return F.relu(linear)\n",
    "\n",
    "linear = MyLinear(5, 3)\n",
    "linear.weight"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e72f5ed0",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Directly carry out forward propagation calculations using custom layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "859b12e2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:14.084676Z",
     "iopub.status.busy": "2023-08-18T19:31:14.084328Z",
     "iopub.status.idle": "2023-08-18T19:31:14.091968Z",
     "shell.execute_reply": "2023-08-18T19:31:14.090864Z"
    },
    "origin_pos": 36,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.9316, 0.0000],\n",
       "        [0.1808, 1.4208, 0.0000]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `linear` is the MyLinear(5, 3) instance created earlier, so the last input dimension must be 5.\n",
    "linear(torch.rand(2, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c0aa88c",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "Construct models using custom layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "53f3a28a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:14.096505Z",
     "iopub.status.busy": "2023-08-18T19:31:14.095515Z",
     "iopub.status.idle": "2023-08-18T19:31:14.104253Z",
     "shell.execute_reply": "2023-08-18T19:31:14.102782Z"
    },
    "origin_pos": 41,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000],\n",
       "        [13.0800]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack two custom layers; the first layer's 8 output units match the second layer's 8 input units.\n",
    "net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))\n",
    "net(torch.rand(2, 64))"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "language_info": {
   "name": "python"
  },
  "required_libs": [],
  "rise": {
   "autolaunch": true,
   "enable_chalkboard": true,
   "overlay": "<div class='my-top-right'><img height=80px src='http://d2l.ai/_static/logo-with-text.png'/></div><div class='my-top-left'></div>",
   "scroll": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}